#########
## Nature Human Behaviour 
## Leading Countries in Global Science Increasingly Receive More Citations than Other Countries Despite Doing Similar Research.
## https://doi.org/10.1038/s41562-022-01351-5
## Harvard Dataverse (Code and Metadata): https://doi.org/10.7910/DVN/WCOINR 
## Step 0F
## Data: Data_20210905
#########

# At a terminal, run the commands below.
# Note: '$1' is the discipline (field-of-study) ID that you pass in.
# Note: the second argument is either 'english_only' or 'all',
# where 'english_only' produces the main results and 'all' the SI results.
# ml python/3.6.1
# ml py-ipython/6.1.0_py36 python/3.6.1 py-scipy/1.1.0_py36 py-scikit-learn/0.19.1_py36 py-pandas/0.23.0_py36 gcc/10.1.0 py-pytorch/1.4.0_py36
# ml py-numpy/1.17.2_py36
# export PYTHONPATH=$GROUP_HOME/python/lib/python3.6/site-packages:$PYTHONPATH
# srun python3 -u Step_X0F_Python3_RR_MAG_Journal_Censored_Yearly_RAKE_and_GoogleAPI_NLLDA.py "$1" "english_only"


#############################
### Inputs
#############################
import sys

# Usage: python3 <this script> <discipline_id> <english_only|all>
# The discipline ID is spliced verbatim into Athena SQL (an unquoted
# `fieldofstudyid=` literal and a table-name suffix), so it must be all digits.
# The language option picks the corpus: "english_only" (main results) or "all" (SI).
_VALID_LANGUAGES = ("english_only", "all")

if len(sys.argv) < 3:
    sys.exit("Usage: " + sys.argv[0] + " <discipline_id> <english_only|all>")

discipline = str(sys.argv[1])
language = str(sys.argv[2])  # "english_only" or "all"

if not discipline.isdigit():
    sys.exit("Error: discipline ID must be a non-negative integer, got " + repr(discipline))
if language not in _VALID_LANGUAGES:
    sys.exit("Error: language must be one of " + ", ".join(_VALID_LANGUAGES) + ", got " + repr(language))

#############################
### Time Start
#############################
import time
start_time = time.time()  # wall-clock time at which the run started

#############################
### Modules
#############################
import sys
from collections import Counter
import random
import bz2 
import pickle
import pandas as pd
import numpy as np
import gc
import itertools
import os 
import string
import re
from nltk.stem import PorterStemmer
from pyathena import connect
import boto3
import smart_open

from Python_Class_LLDA import LLDA

#############################
### Functions
#############################

def reduce_mem_usage(df):
    """Downcast numeric columns of `df` to the smallest dtype that holds their range.

    Signed and unsigned integer columns become the smallest of int8/16/32/64
    whose open interval (min, max) contains the column's values. Float
    columns become float16 or float32 when their range fits, else float64.
    Note that float16 keeps only about three significant decimal digits,
    so the float downcast is lossy.

    Only plain NumPy integer and float columns are touched. object, bool,
    datetime/timedelta, complex, categorical and other extension dtypes are
    left as they are.

    The DataFrame is modified in place and also returned for convenience.
    """
    for col in df.columns:
        col_type = df[col].dtype

        # Extension dtypes (e.g. categorical) are not np.dtype instances; skip them.
        if not isinstance(col_type, np.dtype):
            continue

        if col_type.kind in "iu":
            # An empty integer column has no finite min/max to compare against.
            if df[col].empty:
                continue
            # Plain Python ints compare correctly against every bound below,
            # including unsigned values that do not fit in a signed type.
            c_min = int(df[col].min())
            c_max = int(df[col].max())
            for candidate in (np.int8, np.int16, np.int32, np.int64):
                info = np.iinfo(candidate)
                if c_min > info.min and c_max < info.max:
                    df[col] = df[col].astype(candidate)
                    break
        elif col_type.kind == "f":
            c_min = df[col].min()
            c_max = df[col].max()
            if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:
                df[col] = df[col].astype(np.float16)
            elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:
                df[col] = df[col].astype(np.float32)
            else:
                # Also reached by all-NaN columns, whose min/max compare False.
                df[col] = df[col].astype(np.float64)
    return df

def fastAthenaQuery(input_sql, input_chunksize = None):
    """Run `input_sql` on Athena and return its result set as pandas data.

    PyAthena's API reads only a few rows at a time, which is slow. This helper
    instead streams the query's CSV output file directly from S3 into
    `pd.read_csv`. The output object is then deleted from the bucket.

    Relies on the module-level Athena `cursor`, which must exist before the call.

    Parameters
    ----------
    input_sql : str
        SQL statement to execute.
    input_chunksize : int or None
        None returns a single DataFrame. An int is forwarded to
        `pd.read_csv` as `chunksize`, which returns an iterator of DataFrames.

    Returns
    -------
    pandas.DataFrame, or a chunk iterator when `input_chunksize` is given.
    """
    cursor.execute(input_sql)
    # The s3 path has the form '<bucket_name>/<file_path_inside_the_bucket>';
    # this builds the full path of the query's output file (prefix redacted here).
    complete_s3_path = 'REDACTED' + cursor.output_location.split("REDACTED")[1]

    s3_stream = smart_open.smart_open(complete_s3_path)
    if input_chunksize is None:
        # Eager read: the whole file is parsed here, so the stream can be closed.
        try:
            outputfile = pd.read_csv(s3_stream)
        finally:
            s3_stream.close()
    else:
        # NOTE(review): the chunk reader consumes `s3_stream` lazily. That happens
        # after the S3 object below has been deleted, and the stream is never
        # closed. Confirm that later chunks stay readable in this mode.
        outputfile = pd.read_csv(s3_stream, chunksize=input_chunksize)

    # Delete the query output file from the S3 bucket.
    # NOTE(review): credentials are hard-coded (redacted in this copy); prefer
    # the default boto3 credential chain or environment variables.
    s3 = boto3.resource("REDACTED",
        aws_access_key_id='REDACTED',
        aws_secret_access_key= 'REDACTED')
    obj = s3.Object("REDACTED",
        complete_s3_path)
    obj.delete()

    return outputfile


###############################
### Connections
###############################

# Shared PyAthena settings for both handles below.
# NOTE(review): credentials are hard-coded (redacted in this copy); consider the
# default AWS credential chain or environment variables instead.
_athena_connection_settings = dict(
    aws_access_key_id='REDACTED',
    aws_secret_access_key='REDACTED',
    s3_staging_dir='REDACTED',
    region_name='REDACTED',
)

# Cursor used by fastAthenaQuery() and for the create/drop table statements.
cursor = connect(**_athena_connection_settings).cursor()

# Separate connection object, handed to pd.read_sql.
conn = connect(**_athena_connection_settings)


#############################
### Journal Censoring
#############################

def _read_journal_coverage(path):
    """Return the (journalid, coverage_over_time) rows of `path` for this discipline.

    `discipline` is a string taken from the command line. The filter compares it
    with the string form of `fieldofstudyid`, so rows still match if pandas
    parses that column as integers. A plain `== discipline` against an integer
    column would match nothing.
    """
    coverage = pd.read_csv(path)
    coverage = coverage[coverage["fieldofstudyid"].astype(str) == discipline]
    return coverage[["journalid", "coverage_over_time"]]

journal_df_coverage_1980_1990 = _read_journal_coverage("OUTPUT_Python_MAG_Journal_Coverage_From_1980_to_1990.csv.gz")
journal_df_coverage_1990_2000 = _read_journal_coverage("OUTPUT_Python_MAG_Journal_Coverage_From_1990_to_2000.csv.gz")
journal_df_coverage_2000_2010 = _read_journal_coverage("OUTPUT_Python_MAG_Journal_Coverage_From_2000_to_2010.csv.gz")
journal_df_coverage_2010_2017 = _read_journal_coverage("OUTPUT_Python_MAG_Journal_Coverage_From_2010_to_2017.csv.gz")

# [2000-2017]
# pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
journal_df_coverage_since_2000 = pd.concat([journal_df_coverage_2000_2010, journal_df_coverage_2010_2017])

# [1980-2017]
journal_df_coverage_since_1980 = pd.concat([journal_df_coverage_1990_2000, journal_df_coverage_since_2000, journal_df_coverage_1980_1990])

# Sum the coverage of each journal over each window (since 1980, since 2000).
journal_df_coverage_since_1980 = journal_df_coverage_since_1980.groupby(["journalid"]).agg({"coverage_over_time":"sum"}).reset_index()
journal_df_coverage_since_1980 = journal_df_coverage_since_1980.sort_values(["coverage_over_time"],ascending=False)
journal_df_coverage_since_2000 = journal_df_coverage_since_2000.groupby(["journalid"]).agg({"coverage_over_time":"sum"}).reset_index()

# Keep only the journals whose summed coverage equals the maximum seen in that
# window. The results are Series of journal IDs.
journal_df_coverage_since_1980 = journal_df_coverage_since_1980[journal_df_coverage_since_1980.coverage_over_time == journal_df_coverage_since_1980.coverage_over_time.max()]["journalid"].reset_index(drop=True)
journal_df_coverage_since_2000 = journal_df_coverage_since_2000[journal_df_coverage_since_2000.coverage_over_time == journal_df_coverage_since_2000.coverage_over_time.max()]["journalid"].reset_index(drop=True)

# FIELDIDHERE in the SQL templates below is replaced with `discipline`.
check_paperid_with_journal = '''
select true as returnTrue where EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'temp_paperid_with_journal_FIELDIDHERE');
'''

drop_paperid_with_journal = '''
drop table if exists mag_staging.temp_paperid_with_journal_FIELDIDHERE
'''

# Papers of this field of study published from 1980 on, with their journal and year.
create_table_paperid_with_journal = '''
create table mag_staging.temp_paperid_with_journal_FIELDIDHERE
	WITH (
		format = 'TEXTFILE',
		field_delimiter = '\t'
	  )
as
SELECT ab.paperid as PaperID,
         ab.journalid as JournalID,
         ab.full_year as Year
FROM mag_data.journals c
JOIN 
    (SELECT a.paperid,
         a.full_year,
         a.journalid
    FROM mag_data.papers a
    JOIN 
        (SELECT paperid
        FROM mag_data.paperfieldsofstudy
        WHERE fieldofstudyid=FIELDIDHERE) b
            ON a.paperid = b.paperid) ab
        ON c.journalid = ab.journalid
WHERE ab.full_year>=1980 
'''

query_paperid_with_journal = '''
SELECT * FROM mag_staging.temp_paperid_with_journal_FIELDIDHERE
'''

# The existence check returns a row only when the temporary table is present.
# Testing the frame itself (rather than its `returnTrue` column) does not depend
# on how Athena cases the column label; Athena returns lowercased names
# (e.g. PaperID -> paperid, as used in the merges below).
if pd.read_sql(check_paperid_with_journal.replace("FIELDIDHERE",discipline), conn).empty:
    cursor.execute(create_table_paperid_with_journal.replace("FIELDIDHERE",discipline))

df_paperid_with_journal = fastAthenaQuery(query_paperid_with_journal.replace("FIELDIDHERE",discipline),None)

cursor.execute(drop_paperid_with_journal.replace("FIELDIDHERE",discipline))


#############################
### Check if NL-LDA Models Exist
#############################

# File-name prefix of a previously written yearly NL-LDA file, per language option.
# NOTE(review): this only prints the language option when such a file exists. It
# does not skip or stop the run. These names also lack the
# "Journal_Censored_<window>_" part of the file name written at the end of this script.
_previous_nllda_prefix = {
    "english_only": "OUTPUT_Python_MAG_Yearly_NLLDA_Dict_Corpus_RAKE_and_GoogleAPI_EnglishOnly_",
    "all": "OUTPUT_Python_MAG_Yearly_NLLDA_Dict_Corpus_RAKE_and_GoogleAPI_All_",
}
if language in _previous_nllda_prefix:
    _previous_nllda_file = _previous_nllda_prefix[language] + str(discipline) + ".pbz2"
    if os.path.isfile(_previous_nllda_file):
        print(language)

#############################
### Read in Files | Metadata
#############################

Field_MetaData_Filename = "OUTPUT_Python_MAG_Field_MetaData_Dict_"+str(discipline)+".pbz2"

# The file is read as a sequence of pickles. Only the first two are needed
# here: GRID labels, then years. The later entries, not loaded by this step,
# appear to be dict_coauthor_edgelist, dict_mobility_edgelist,
# dict_citation_edgelist and dict_citation_year_edgelist.
# The `with` block closes the file as soon as both objects are loaded.
with bz2.BZ2File(Field_MetaData_Filename, 'rb') as f:
    dict_grid_labels = pickle.load(f)
    dict_year = pickle.load(f)

# Convert to DataFrames and free each intermediate dict right away to keep
# peak memory down.
df_grid_labels = pd.DataFrame(dict_grid_labels)
del dict_grid_labels
gc.collect()

df_year = pd.DataFrame(dict_year)
del dict_year
gc.collect()

#############################
### Read in Files | Corpus
#############################

def _load_pickled_corpus(path):
    """Return the first pickled object stored in the bz2 file at `path`."""
    with bz2.BZ2File(path, 'rb') as corpus_file:
        return pickle.load(corpus_file)

# `corpus` ends up as a dict {paperid: Abstract}. Any failure while loading or
# reshaping a candidate file (missing file, unreadable pickle, missing column)
# moves on to the next candidate. If none works, the run stops with a non-zero
# exit status, because the later steps need `corpus`.

if language=="english_only":
    try:
        # Full corpus, filtered to the rows whose Is_English flag is 'en'.
        Field_Data_Filename_Corpus = "OUTPUT_Python_MAG_Field_Corpus_RAKE_and_GoogleAPI_"+str(discipline)+".pbz2"
        corpus = _load_pickled_corpus(Field_Data_Filename_Corpus)
        corpus = corpus.query("Is_English=='en'")[["paperid","Abstract"]].set_index("paperid").to_dict()["Abstract"]
    except Exception:
        try:
            # Fallback file, used without further filtering (presumably already English-only).
            Field_Data_Filename_Corpus = "TEMP_English_OUTPUT_Python_MAG_Field_Corpus_RAKE_and_GoogleAPI_"+str(discipline)+".pbz2"
            corpus = _load_pickled_corpus(Field_Data_Filename_Corpus)
            corpus = corpus[["paperid","Abstract"]].set_index("paperid").to_dict()["Abstract"]
        except Exception:
            print("------Error English Only Corpus Missing-------")
            sys.exit(1)


if language=="all":
    try:
        Field_Data_Filename_Corpus = "OUTPUT_Python_MAG_Field_Corpus_RAKE_and_GoogleAPI_"+str(discipline)+".pbz2"
        corpus = _load_pickled_corpus(Field_Data_Filename_Corpus)
        corpus = corpus[["paperid","Abstract"]].set_index("paperid").to_dict()["Abstract"]
    except Exception:
        print("Error All Corpus Missing")
        sys.exit(1)

#############################
## GRID Database with GRID from Field
#############################

# Attach each GRID id's country and turn it into a label token:
# trimmed, upper-cased, with spaces replaced by underscores.
df_grid_database = pd.read_csv("addresses.csv")

grid_countries = df_grid_database.loc[:, ["grid_id", "country"]].rename(columns={"grid_id": "gridid"})
df_grid_labels = df_grid_labels.merge(grid_countries, how="left", on=["gridid"])
df_grid_labels["labels"] = (
    df_grid_labels["country"]
    .str.strip()
    .str.upper()
    .str.replace(" ", "_")
)

# Drop rows with a missing value in any column, then keep the two used later.
df_grid_labels = df_grid_labels.dropna()[["paperid", "labels"]]

#############################
## Labeled LDA | Corpora by Year and Labels by Year
#############################

############
## RAKE Corpus
############
# Papers with a missing year get 0 so that the column can be stored as int.
df_year["year"] = df_year["year"].fillna(0).astype(int)

# Matches every character that is neither a word character nor whitespace.
_PUNCTUATION_RE = re.compile(r'[^\w\s]')

def clean_ngrams(x):
    """Normalise one corpus term.

    Removes punctuation and maps all-digit terms to ''. Trims surrounding
    whitespace, then drops a single trailing 's' (a crude de-pluralisation)
    from terms longer than one character.
    """
    x = _PUNCTUATION_RE.sub('', x)
    # Checked before trimming, so digits with surrounding spaces are kept.
    if x.isdigit():
        return ''
    x = x.strip()
    if len(x) > 1 and x[-1] == 's':
        x = x[:-1]
    return x

# One space-separated label string per paper, built from its rows in
# df_grid_labels (repeated countries are kept).
df_labels = df_grid_labels.groupby('paperid')["labels"].apply(lambda x: " ".join(x)).reset_index(name="Labels")
labels_dict = pd.Series(df_labels.Labels.values,index=df_labels.paperid).to_dict()

NLLDA_Min_Year = 1980
NLLDA_Max_Year = 2017
NLLDA_Betas = [0.1, 0.5, 0.9]

for journal_censor_ in ["Since_1980"]:

    # Papers from the journals kept by the chosen censoring window.
    if journal_censor_ == "Since_1980":
        censored_journals_ = journal_df_coverage_since_1980
    elif journal_censor_ == "Since_2000":
        censored_journals_ = journal_df_coverage_since_2000
    else:
        raise ValueError("Unknown journal censoring window: " + str(journal_censor_))
    df_year_censored = pd.merge(pd.merge(pd.DataFrame(censored_journals_),df_paperid_with_journal,on=["journalid"]),df_year,on=["paperid","year"],how="left")

    year_dict = pd.Series(df_year_censored.year.values,index=df_year_censored.paperid).to_dict()
    list_of_years = list(set(year_dict.values()))

    Yearly_Dict_of_Corpora = {years_:{} for years_ in list_of_years}
    Yearly_Dict_of_Labels = {years_:{} for years_ in list_of_years}

    # Keep papers that have both a non-empty abstract and a non-empty label string.
    for paperid_, yearid_ in year_dict.items():
        try:
            abstract_ = corpus[paperid_]
            paper_labels_ = labels_dict[paperid_]
        except KeyError:
            # No abstract in the corpus, or no country label, for this paper.
            continue
        if abstract_ != [] and paper_labels_ != "":
            Yearly_Dict_of_Corpora[yearid_][paperid_] = abstract_
            Yearly_Dict_of_Labels[yearid_][paperid_] = paper_labels_

    # Clean every term, then drop those shorter than two characters.
    Yearly_Dict_of_Corpora = {
        years_: {
            paperid_: [term_ for term_ in (clean_ngrams(str(raw_)) for raw_ in abstract_) if len(term_) > 1]
            for paperid_, abstract_ in year_corpora.items()
        }
        for years_, year_corpora in Yearly_Dict_of_Corpora.items()
    }


    ##############################
    ### Labeled LDA
    ##############################
    # Prior weight of topic k in a document (originally 0.01). It is usually the
    # same for all topics; a value below 1 prefers sparse topic distributions,
    # i.e. few topics per document.
    alpha = 0.1
    # Not used when building the models below.
    iteration = 1000 # originally 100 - now 1,000
    seed = None
    samplesize = 100

    Dictionary_of_NLLDA = {}

    for beta in NLLDA_Betas:
        # beta: prior weight of word w in a topic.
        Dictionary_of_NLLDA[beta] = {}

        for year_ in range(NLLDA_Min_Year,NLLDA_Max_Year+1,1):
            # A year with no papers at all has no entry. Skip it; the writer
            # below stores an empty placeholder for it.
            if year_ not in Yearly_Dict_of_Labels:
                continue

            labels_list = [label_.split(" ") for label_ in Yearly_Dict_of_Labels[year_].values()]
            labels_set = list(set(itertools.chain.from_iterable(labels_list)))

            K = len(labels_set) # Number of labels

            NLLDA_Model_ = LLDA(K, alpha, beta)
            NLLDA_Model_.set_corpus(labels_set, Yearly_Dict_of_Corpora[year_].values(), labels_list)
            Dictionary_of_NLLDA[beta][year_] = NLLDA_Model_

    if language == "english_only":
        Yearly_NLLDA_Dict_Filename = "OUTPUT_Python_MAG_Journal_Censored_"+str(journal_censor_)+"_Yearly_NLLDA_Dict_Corpus_RAKE_and_GoogleAPI_EnglishOnly_"+str(discipline)+".pbz2"
    else:
        Yearly_NLLDA_Dict_Filename = "OUTPUT_Python_MAG_Journal_Censored_"+str(journal_censor_)+"_Yearly_NLLDA_Dict_Corpus_RAKE_and_GoogleAPI_All_"+str(discipline)+".pbz2"

    # One pickle per (beta, year), beta-major and in ascending year order, so a
    # reader can load them back sequentially. A year without a model gets {}.
    gc.disable() # Memory Issues When Pickling
    try:
        with bz2.BZ2File(Yearly_NLLDA_Dict_Filename, 'w') as f:
            for beta in NLLDA_Betas:
                for year in range(NLLDA_Min_Year,NLLDA_Max_Year+1):
                    print("beta "+str(beta)+" year "+str(year))
                    pickle.dump(Dictionary_of_NLLDA[beta].get(year, {}), f, protocol=2)
    finally:
        # Re-enable the collector even if pickling fails.
        gc.enable()
